{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('/home/yuke/PythonProject/DrugEmbedding/')\n",
    "import warnings\n",
    "warnings.simplefilter(action='ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tnrange\n",
    "import json\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random\n",
    "from decode import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed Python's built-in RNG so the random.sample() draws made later in the notebook are repeatable.\n",
    "# NOTE(review): this seeds only the `random` module -- confirm whether the decoder's\n",
    "# 'random' sampling mode draws from another generator that needs its own seed.\n",
    "random.seed(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def recon_acc_score(configs, model, smiles_sample_lst, n_decode=200):\n",
    "    \"\"\"Estimate the fraction of SMILES strings the model can reconstruct.\n",
    "\n",
    "    For every input string the mean returned by `smiles2mean` is repeated\n",
    "    `n_decode` times and passed to `latent2smiles` with\n",
    "    sampling_mode='random'. The input counts as reconstructed when it appears\n",
    "    verbatim among the decoded strings.\n",
    "\n",
    "    Args:\n",
    "        configs: experiment configuration, forwarded unchanged to the helpers.\n",
    "        model: loaded model, forwarded unchanged to the helpers.\n",
    "        smiles_sample_lst: sequence of SMILES strings to evaluate.\n",
    "        n_decode: number of copies of the mean that are decoded per input.\n",
    "            Defaults to 200, the value that used to be hard-coded.\n",
    "\n",
    "    Returns:\n",
    "        Mean of the per-input 0/1 match indicators (NaN for an empty input).\n",
    "    \"\"\"\n",
    "    match_lst = []\n",
    "    for i in tnrange(len(smiles_sample_lst)):\n",
    "        smiles_x = smiles_sample_lst[i]\n",
    "        # Only the first return value (the mean) is needed for decoding.\n",
    "        mean, _ = smiles2mean(configs, smiles_x, model)\n",
    "        _, _, smiles_lst = latent2smiles(configs, model, z=mean.repeat(n_decode, 1),\n",
    "                                         nsamples=1, sampling_mode='random')\n",
    "        # Exact string comparison: a differently written SMILES is not a match.\n",
    "        match_lst.append(int(smiles_x in smiles_lst))\n",
    "    return np.array(match_lst).mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare Model Reference Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Latent size 64: three Euclidean runs followed by three Lorentz runs.\n",
    "latent_size_lst_64 = [64] * 6\n",
    "manifold_lst_64 = ['Euclidean'] * 3 + ['Lorentz'] * 3\n",
    "exp_dir_lst_64 = [\n",
    "    './experiments/KDD/kdd_009',\n",
    "    './experiments/KDD_SEED/kdd_e64_s1',\n",
    "    './experiments/KDD_SEED/kdd_e64_s2',\n",
    "    './experiments/KDD/kdd_010',\n",
    "    './experiments/KDD_SEED/kdd_l64_s1',\n",
    "    './experiments/KDD_SEED/kdd_l64_s2',\n",
    "]\n",
    "# Checkpoint file to load for each run, in the same order as exp_dir_lst_64.\n",
    "checkpoint_lst_64 = [\n",
    "    'checkpoint_epoch110.model',\n",
    "    'checkpoint_epoch120.model',\n",
    "    'checkpoint_epoch120.model',\n",
    "    'checkpoint_epoch110.model',\n",
    "    'checkpoint_epoch120.model',\n",
    "    'checkpoint_epoch120.model',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Latent size 32: three Euclidean runs followed by three Lorentz runs.\n",
    "latent_size_lst_32 = [32] * 6\n",
    "manifold_lst_32 = ['Euclidean'] * 3 + ['Lorentz'] * 3\n",
    "exp_dir_lst_32 = [\n",
    "    './experiments/KDD/kdd_015',\n",
    "    './experiments/KDD_SEED/kdd_e32_s1',\n",
    "    './experiments/KDD_SEED/kdd_e32_s2',\n",
    "    './experiments/KDD/kdd_016',\n",
    "    './experiments/KDD_SEED/kdd_l32_s1',\n",
    "    './experiments/KDD_SEED/kdd_l32_s2',\n",
    "]\n",
    "# Checkpoint file to load for each run, in the same order as exp_dir_lst_32.\n",
    "checkpoint_lst_32 = [\n",
    "    'checkpoint_epoch110.model',\n",
    "    'checkpoint_epoch130.model',\n",
    "    'checkpoint_epoch110.model',\n",
    "    'checkpoint_epoch110.model',\n",
    "    'checkpoint_epoch130.model',\n",
    "    'checkpoint_epoch110.model',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Latent size 8: three Euclidean runs followed by three Lorentz runs.\n",
    "latent_size_lst_8 = [8] * 6\n",
    "manifold_lst_8 = ['Euclidean'] * 3 + ['Lorentz'] * 3\n",
    "exp_dir_lst_8 = [\n",
    "    './experiments/KDD/kdd_017',\n",
    "    './experiments/KDD_SEED/kdd_e8_s1',\n",
    "    './experiments/KDD_SEED/kdd_e8_s2',\n",
    "    './experiments/KDD/kdd_018',\n",
    "    './experiments/KDD_SEED/kdd_l8_s1',\n",
    "    './experiments/KDD_SEED/kdd_l8_s2',\n",
    "]\n",
    "# Checkpoint file to load for each run, in the same order as exp_dir_lst_8.\n",
    "checkpoint_lst_8 = [\n",
    "    'checkpoint_epoch110.model',\n",
    "    'checkpoint_epoch120.model',\n",
    "    'checkpoint_epoch110.model',\n",
    "    'checkpoint_epoch120.model',\n",
    "    'checkpoint_epoch120.model',\n",
    "    'checkpoint_epoch110.model',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Latent size 4: three Euclidean runs followed by three Lorentz runs.\n",
    "latent_size_lst_4 = [4] * 6\n",
    "manifold_lst_4 = ['Euclidean'] * 3 + ['Lorentz'] * 3\n",
    "exp_dir_lst_4 = [\n",
    "    './experiments/KDD/kdd_019',\n",
    "    './experiments/KDD_SEED/kdd_e4_s1',\n",
    "    './experiments/KDD_SEED/kdd_e4_s2',\n",
    "    './experiments/KDD/kdd_020',\n",
    "    './experiments/KDD_SEED/kdd_l4_s1',\n",
    "    './experiments/KDD_SEED/kdd_l4_s2',\n",
    "]\n",
    "# Checkpoint file to load for each run, in the same order as exp_dir_lst_4.\n",
    "checkpoint_lst_4 = [\n",
    "    'checkpoint_epoch100.model',\n",
    "    'checkpoint_epoch110.model',\n",
    "    'checkpoint_epoch110.model',\n",
    "    'checkpoint_epoch100.model',\n",
    "    'checkpoint_epoch100.model',\n",
    "    'checkpoint_epoch090.model',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Latent size 2: three Euclidean runs followed by three Lorentz runs.\n",
    "latent_size_lst_2 = [2] * 6\n",
    "manifold_lst_2 = ['Euclidean'] * 3 + ['Lorentz'] * 3\n",
    "exp_dir_lst_2 = [\n",
    "    './experiments/KDD/kdd_021',\n",
    "    './experiments/KDD_SEED/kdd_e2_s1',\n",
    "    './experiments/KDD_SEED/kdd_e2_s2',\n",
    "    './experiments/KDD/kdd_022',\n",
    "    './experiments/KDD_SEED/kdd_l2_s1',\n",
    "    './experiments/KDD_SEED/kdd_l2_s2',\n",
    "]\n",
    "# Checkpoint file to load for each run, in the same order as exp_dir_lst_2.\n",
    "checkpoint_lst_2 = [\n",
    "    'checkpoint_epoch110.model',\n",
    "    'checkpoint_epoch100.model',\n",
    "    'checkpoint_epoch090.model',\n",
    "    'checkpoint_epoch110.model',\n",
    "    'checkpoint_epoch100.model',\n",
    "    'checkpoint_epoch090.model',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-size lists in the order 64, 32, 8, 4, 2.\n",
    "latent_size_lst = [\n",
    "    *latent_size_lst_64,\n",
    "    *latent_size_lst_32,\n",
    "    *latent_size_lst_8,\n",
    "    *latent_size_lst_4,\n",
    "    *latent_size_lst_2,\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-size lists in the order 64, 32, 8, 4, 2.\n",
    "manifold_lst = [\n",
    "    *manifold_lst_64,\n",
    "    *manifold_lst_32,\n",
    "    *manifold_lst_8,\n",
    "    *manifold_lst_4,\n",
    "    *manifold_lst_2,\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-size lists in the order 64, 32, 8, 4, 2.\n",
    "exp_dir_lst = [\n",
    "    *exp_dir_lst_64,\n",
    "    *exp_dir_lst_32,\n",
    "    *exp_dir_lst_8,\n",
    "    *exp_dir_lst_4,\n",
    "    *exp_dir_lst_2,\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-size lists in the order 64, 32, 8, 4, 2.\n",
    "checkpoint_lst = [\n",
    "    *checkpoint_lst_64,\n",
    "    *checkpoint_lst_32,\n",
    "    *checkpoint_lst_8,\n",
    "    *checkpoint_lst_4,\n",
    "    *checkpoint_lst_2,\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per trained model; the four lists are aligned by position.\n",
    "df = pd.DataFrame({\n",
    "    'latent_size': latent_size_lst,\n",
    "    'manifold': manifold_lst,\n",
    "    'exp_dir': exp_dir_lst,\n",
    "    'checkpoint': checkpoint_lst,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the reference table. The output folder is created first so the\n",
    "# write does not fail when ./experiments/RECON does not exist yet.\n",
    "os.makedirs('./experiments/RECON', exist_ok=True)\n",
    "df.to_csv('./experiments/RECON/model_dir.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load SMILES Test Set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the model reference table and add an empty column that will hold the scores.\n",
    "mdl_dir_df = pd.read_csv('./experiments/RECON/model_dir.csv').assign(recon_acc=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the SMILES test set stored in the first model's experiment folder.\n",
    "exp_dir = mdl_dir_df['exp_dir'].iloc[0]\n",
    "smiles_test_file = os.path.join(exp_dir, 'smiles_test.smi')\n",
    "with open(smiles_test_file) as fh:\n",
    "    # A line is either a bare SMILES (ZINC 250k) or 'SMILES name' (FDA drug);\n",
    "    # in both cases the SMILES is the first space-separated field. Blank lines\n",
    "    # are skipped so no empty string ends up in the evaluation pool. The\n",
    "    # instance IDs the old loop built were never used and rebound the builtin\n",
    "    # `id` at notebook scope, so they are no longer created.\n",
    "    smiles_test_lst = [line.split(' ')[0] for line in fh.read().splitlines() if line.strip()]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluate Molecule Reconstruction Accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## latent size of 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Boolean mask and subset for the models with a latent size of 64.\n",
    "idx = mdl_dir_df['latent_size'].eq(64)\n",
    "sub_df = mdl_dir_df.loc[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "latent_size                           64\n",
      "manifold                       Euclidean\n",
      "exp_dir        ./experiments/KDD/kdd_009\n",
      "checkpoint     checkpoint_epoch110.model\n",
      "recon_acc                           None\n",
      "Name: 0, dtype: object\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7808f8c63f244477a2b0ebf80bd2b6ac",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Recon. accuracy: 0.943\n",
      "------------------------------------\n",
      "latent_size                                   64\n",
      "manifold                               Euclidean\n",
      "exp_dir        ./experiments/KDD_SEED/kdd_e64_s1\n",
      "checkpoint             checkpoint_epoch120.model\n",
      "recon_acc                                   None\n",
      "Name: 1, dtype: object\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5efcf170b76d4b8a95142d6ab51f4a22",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Recon. accuracy: 0.97\n",
      "------------------------------------\n",
      "latent_size                                   64\n",
      "manifold                               Euclidean\n",
      "exp_dir        ./experiments/KDD_SEED/kdd_e64_s2\n",
      "checkpoint             checkpoint_epoch120.model\n",
      "recon_acc                                   None\n",
      "Name: 2, dtype: object\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1e03779699c2424da583dccaf89492b6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Recon. accuracy: 0.941\n",
      "------------------------------------\n",
      "latent_size                           64\n",
      "manifold                         Lorentz\n",
      "exp_dir        ./experiments/KDD/kdd_010\n",
      "checkpoint     checkpoint_epoch110.model\n",
      "recon_acc                           None\n",
      "Name: 3, dtype: object\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "228a751852dc4da48483fb1ec6960fb9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Recon. accuracy: 0.93\n",
      "------------------------------------\n",
      "latent_size                                   64\n",
      "manifold                                 Lorentz\n",
      "exp_dir        ./experiments/KDD_SEED/kdd_l64_s1\n",
      "checkpoint             checkpoint_epoch120.model\n",
      "recon_acc                                   None\n",
      "Name: 4, dtype: object\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "93e5b159fa9943c78b816c5c600ac432",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Recon. accuracy: 0.954\n",
      "------------------------------------\n",
      "latent_size                                   64\n",
      "manifold                                 Lorentz\n",
      "exp_dir        ./experiments/KDD_SEED/kdd_l64_s2\n",
      "checkpoint             checkpoint_epoch120.model\n",
      "recon_acc                                   None\n",
      "Name: 5, dtype: object\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cd21ac56b9b84ba999777b1af3b46ad5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Recon. accuracy: 0.97\n",
      "------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Score every model in sub_df and record the result in the reference table.\n",
    "for idx, row in sub_df.iterrows():\n",
    "    print(row)\n",
    "    exp_dir = row['exp_dir']\n",
    "    checkpoint = row['checkpoint']\n",
    "    config_path = os.path.join(exp_dir, 'configs.json')\n",
    "    # The with-block closes the file; no explicit close() is needed afterwards.\n",
    "    with open(config_path, 'r') as fp:\n",
    "        configs = json.load(fp)\n",
    "    # Point the loaded config at the checkpoint chosen for this model.\n",
    "    configs['checkpoint'] = checkpoint\n",
    "    model = load_model(configs)\n",
    "    smiles_sample_lst = random.sample(smiles_test_lst, 1000)\n",
    "    recon_score = recon_acc_score(configs, model, smiles_sample_lst)\n",
    "    # Single label-based write on the parent frame. The former chained\n",
    "    # ['recon_acc'].iloc[idx] form could assign to a temporary copy and used\n",
    "    # the index label as if it were a position.\n",
    "    mdl_dir_df.loc[idx, 'recon_acc'] = recon_score\n",
    "    print('Recon. accuracy: ' + str(recon_score))\n",
    "    print('------------------------------------')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## latent size of 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Boolean mask and subset for the models with a latent size of 32.\n",
    "idx = mdl_dir_df['latent_size'].eq(32)\n",
    "sub_df = mdl_dir_df.loc[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "latent_size                           32\n",
      "manifold                       Euclidean\n",
      "exp_dir        ./experiments/KDD/kdd_015\n",
      "checkpoint     checkpoint_epoch110.model\n",
      "recon_acc                           None\n",
      "Name: 6, dtype: object\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4d82a6c80b194d1494a426113abaf3dd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Recon. accuracy: 0.865\n",
      "------------------------------------\n",
      "latent_size                                   32\n",
      "manifold                               Euclidean\n",
      "exp_dir        ./experiments/KDD_SEED/kdd_e32_s1\n",
      "checkpoint             checkpoint_epoch130.model\n",
      "recon_acc                                   None\n",
      "Name: 7, dtype: object\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5406807c41784d569d3d513d151c9edd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Recon. accuracy: 0.911\n",
      "------------------------------------\n",
      "latent_size                                   32\n",
      "manifold                               Euclidean\n",
      "exp_dir        ./experiments/KDD_SEED/kdd_e32_s2\n",
      "checkpoint             checkpoint_epoch110.model\n",
      "recon_acc                                   None\n",
      "Name: 8, dtype: object\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c331d1d165d146cb953766cabf9730f6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Recon. accuracy: 0.937\n",
      "------------------------------------\n",
      "latent_size                           32\n",
      "manifold                         Lorentz\n",
      "exp_dir        ./experiments/KDD/kdd_016\n",
      "checkpoint     checkpoint_epoch110.model\n",
      "recon_acc                           None\n",
      "Name: 9, dtype: object\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "50f5ba6da3554ae0b407b0118564fdf3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Recon. accuracy: 0.881\n",
      "------------------------------------\n",
      "latent_size                                   32\n",
      "manifold                                 Lorentz\n",
      "exp_dir        ./experiments/KDD_SEED/kdd_l32_s1\n",
      "checkpoint             checkpoint_epoch130.model\n",
      "recon_acc                                   None\n",
      "Name: 10, dtype: object\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a6f7ddd3186d4017a9f32d3a800dc5d5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Recon. accuracy: 0.948\n",
      "------------------------------------\n",
      "latent_size                                   32\n",
      "manifold                                 Lorentz\n",
      "exp_dir        ./experiments/KDD_SEED/kdd_l32_s2\n",
      "checkpoint             checkpoint_epoch110.model\n",
      "recon_acc                                   None\n",
      "Name: 11, dtype: object\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "add7f5abb0444c3786656f9b618897c0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Recon. accuracy: 0.89\n",
      "------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Score every model in sub_df and record the result in the reference table.\n",
    "for idx, row in sub_df.iterrows():\n",
    "    print(row)\n",
    "    exp_dir = row['exp_dir']\n",
    "    checkpoint = row['checkpoint']\n",
    "    config_path = os.path.join(exp_dir, 'configs.json')\n",
    "    # The with-block closes the file; no explicit close() is needed afterwards.\n",
    "    with open(config_path, 'r') as fp:\n",
    "        configs = json.load(fp)\n",
    "    # Point the loaded config at the checkpoint chosen for this model.\n",
    "    configs['checkpoint'] = checkpoint\n",
    "    model = load_model(configs)\n",
    "    smiles_sample_lst = random.sample(smiles_test_lst, 1000)\n",
    "    recon_score = recon_acc_score(configs, model, smiles_sample_lst)\n",
    "    # Single label-based write on the parent frame. The former chained\n",
    "    # ['recon_acc'].iloc[idx] form could assign to a temporary copy and used\n",
    "    # the index label as if it were a position.\n",
    "    mdl_dir_df.loc[idx, 'recon_acc'] = recon_score\n",
    "    print('Recon. accuracy: ' + str(recon_score))\n",
    "    print('------------------------------------')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Structure Only Models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Euclidean Space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "426f8f572e854ddb97a999b6b266be63",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Recon. accuracy: 0.905\n",
      "------------------------------------\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "738b81e222c3404ab433158d1b0c7155",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Recon. accuracy: 0.899\n",
      "------------------------------------\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e04971f926c34f5ba2fd575e31fd22b7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Recon. accuracy: 0.894\n",
      "------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the structure-only Euclidean model on three seeded samples of the test set.\n",
    "seed_lst = [0, 1, 2]\n",
    "exp_dir = './experiments/EXP_TASK/exp_task_009'\n",
    "checkpoint = 'checkpoint_epoch100.model'\n",
    "config_path = os.path.join(exp_dir, 'configs.json')\n",
    "# The with-block closes the file; no explicit close() is needed afterwards.\n",
    "with open(config_path, 'r') as fp:\n",
    "    configs = json.load(fp)\n",
    "configs['checkpoint'] = checkpoint\n",
    "model = load_model(configs)\n",
    "# Scores are kept in their own list. This cell used to write them into\n",
    "# mdl_dir_df through an `idx` left over from an earlier cell, which overwrote\n",
    "# the recon_acc entry of a different model.\n",
    "struct_scores_euclidean = []\n",
    "for s in seed_lst:\n",
    "    # Re-seed so each of the three samples is reproducible on its own.\n",
    "    random.seed(s)\n",
    "    smiles_sample_lst = random.sample(smiles_test_lst, 1000)\n",
    "    recon_score = recon_acc_score(configs, model, smiles_sample_lst)\n",
    "    struct_scores_euclidean.append(recon_score)\n",
    "    print('Recon. accuracy: ' + str(recon_score))\n",
    "    print('------------------------------------')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Lorentz Space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "94ca60f020eb4c91b261eacc80decbb7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Recon. accuracy: 0.885\n",
      "------------------------------------\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "16c6528aa2cc45beaf936a74c24ce4ed",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Recon. accuracy: 0.854\n",
      "------------------------------------\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d4319ecf34d84190b8250753a4a8e923",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Recon. accuracy: 0.873\n",
      "------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the structure-only Lorentz model on three seeded samples of the test set.\n",
    "seed_lst = [0, 1, 2]\n",
    "exp_dir = './experiments/EXP_TASK/exp_task_010'\n",
    "checkpoint = 'checkpoint_epoch075.model'\n",
    "config_path = os.path.join(exp_dir, 'configs.json')\n",
    "# The with-block closes the file; no explicit close() is needed afterwards.\n",
    "with open(config_path, 'r') as fp:\n",
    "    configs = json.load(fp)\n",
    "configs['checkpoint'] = checkpoint\n",
    "model = load_model(configs)\n",
    "# Scores are kept in their own list. This cell used to write them into\n",
    "# mdl_dir_df through an `idx` left over from an earlier cell, which overwrote\n",
    "# the recon_acc entry of a different model.\n",
    "struct_scores_lorentz = []\n",
    "for s in seed_lst:\n",
    "    # Re-seed so each of the three samples is reproducible on its own.\n",
    "    random.seed(s)\n",
    "    smiles_sample_lst = random.sample(smiles_test_lst, 1000)\n",
    "    recon_score = recon_acc_score(configs, model, smiles_sample_lst)\n",
    "    struct_scores_lorentz.append(recon_score)\n",
    "    print('Recon. accuracy: ' + str(recon_score))\n",
    "    print('------------------------------------')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8706666666666667"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean of the three accuracies printed by the Lorentz structure-only cell.\n",
    "np.mean([0.885, 0.854, 0.873])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.01276279314605111"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Population standard deviation (ddof=0) of the same three accuracies.\n",
    "np.std([0.885, 0.854, 0.873])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MolEnv",
   "language": "python",
   "name": "molenv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "193px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
